{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "2d579aba-7439-42a6-aae4-cf1983095dee",
   "metadata": {},
   "source": [
    "# Example 08 - Optuna with Wideband and YOLO\n",
    "This notebook showcases Optuna hyperparameter tuning a YOLOv8 model with Wideband.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "40a026bd-f096-47f3-a262-48ab5defe23e",
   "metadata": {},
   "source": [
    "## Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4bff6d5-4b2d-4db2-97a0-f45843e7cc60",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Packages for Optuna\n",
    "from torchsig.utils.optuna.tuner import YoloOptunaOptimizer\n",
    "import copy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ab949825-7c48-44f2-852d-56567f5952e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Packages Imports for Training\n",
    "from ultralytics import YOLO\n",
    "from torchsig.utils.yolo_train import Yolo_Trainer\n",
    "from torchsig.datasets.datamodules import WidebandDataModule\n",
    "from torchsig.transforms.transforms import Compose, Spectrogram, Normalize, SpectrogramImage\n",
    "from torchsig.transforms.target_transforms import DescToBBoxFamilyDict\n",
    "from torchsig.datasets.signal_classes import torchsig_signals\n",
    "import torch\n",
    "import yaml\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "efea64b7-1d22-41bd-81cc-5c39a74c3bdd",
   "metadata": {},
   "source": [
    "-----------------------------\n",
    "## Check or Generate the Wideband Dataset\n",
    "To generate the Wideband dataset, several parameters are given to the imported `WidebandDataModule` class. These paramters are:\n",
    "- `root` ~ A string to specify the root directory of where to generate and/or read an existing Wideband dataset\n",
    "- `train` ~ A boolean to specify if the Wideband dataset should be the training (True) or validation (False) sets\n",
    "- `qa` - A boolean to specify whether to generate a small subset of Wideband (True), or the full dataset (False), default is True\n",
    "- `impaired` ~ A boolean to specify if the Wideband dataset should be the clean version or the impaired version\n",
    "- `transform` ~ Optionally, pass in any data transforms here if the dataset will be used in an ML training pipeline. Note: these transforms are not called during the dataset generation. The static saved dataset will always be in IQ format. The transform is only called when retrieving data examples.\n",
    "- `target_transform` ~ Optionally, pass in any target transforms here if the dataset will be used in an ML training pipeline. Note: these target transforms are not called during the dataset generation. The static saved dataset will always be saved as tuples in the LMDB dataset. The target transform is only called when retrieving data examples.\n",
    "\n",
    "A combination of the `train` and the `impaired` booleans determines which of the four (4) distinct Wideband datasets will be instantiated:\n",
    "| `impaired` | `qa` | Result |\n",
    "| ---------- | ---- | ------- |\n",
    "| `False` | `False` | Clean datasets of train=250k examples and val=25k examples |\n",
    "| `False` | `True` | Clean datasets of train=250 examples and val=250 examples |\n",
    "| `True` | `False` | Impaired datasets of train=250k examples and val=25k examples |\n",
    "| `True` | `True` | Impaired datasets of train=250 examples and val=250 examples |\n",
    "\n",
    "The final option of the impaired validation set is the dataset to be used when reporting any results with the official Wideband dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ca6e8bfb-92f3-4615-afac-7568350a2461",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate Wideband DataModule\n",
    "root = \"./datasets/wideband\"\n",
    "impaired = True\n",
    "qa = True\n",
    "fft_size = 512\n",
    "class_list = torchsig_signals.class_list\n",
    "num_classes = len(class_list)\n",
    "batch_size = 1\n",
    "\n",
    "transform = Compose([    \n",
    "])\n",
    "\n",
    "target_transform = Compose([\n",
    "    DescToBBoxFamilyDict()\n",
    "])\n",
    "\n",
    "datamodule = WidebandDataModule(\n",
    "    root=root,\n",
    "    impaired=impaired,\n",
    "    qa=qa,\n",
    "    fft_size=fft_size,\n",
    "    num_classes=num_classes,\n",
    "    transform=transform,\n",
    "    target_transform=target_transform,\n",
    "    batch_size=batch_size\n",
    ")\n",
    "\n",
    "datamodule.prepare_data()\n",
    "datamodule.setup(\"fit\")\n",
    "\n",
    "wideband_train = datamodule.train\n",
    "\n",
    "# Retrieve a sample and print out information\n",
    "idx = np.random.randint(len(wideband_train))\n",
    "data, label = wideband_train[idx]\n",
    "print(\"Dataset length: {}\".format(len(wideband_train)))\n",
    "print(\"Data shape: {}\".format(data.shape))\n",
    "print(\"Label: {}\".format(label))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2938efda-7b7f-42e9-8bcf-022ebcaf2d32",
   "metadata": {},
   "source": [
    "## Prepare YOLO trainer and Model\n",
    "Next, the datasets are rewritten to disk that is Ultralytics YOLO compatible. See [Ultralytics: Train Custom Data - Organize Directories](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#23-organize-directories) to learn more. \n",
    "\n",
    "Additionally, create a yaml file for dataset configuration. See [Ultralytics: Train Custom Data - Create dataset.yaml](https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#21-create-datasetyaml)\n",
    "\n",
    "Download desired YOLO model from [Ultralytics Models](https://docs.ultralytics.com/models/). We will use YOLOv8, specifically `yolov8n.pt`\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f8f5ad86-bb23-4396-8cd3-3b1c1c9a32b2",
   "metadata": {},
   "source": [
    "### Explanation of the `overrides` Dictionary\n",
    "\n",
    "The `overrides` dictionary is used to customize the settings for the Ultralytics YOLO trainer by specifying specific values that override the default configurations. The dictionary is imported from `wbdata.yaml`. However, you can customize in the notebook. \n",
    "\n",
    "Example:\n",
    "\n",
    "```python\n",
    "overrides = {'model': 'yolov8n.pt', 'epochs': 100, 'data': '08_example.yaml', 'device': 0, 'imgsz': 512, 'single_cls': True}\n",
    "```\n",
    "A .yaml is necessary for training. Look at `08_yolo_optuna.yaml` in the examples directory. It will contain the path to your torchsig data.\n",
    "\n",
    "\n",
    "### Dataset Location Warning\n",
    "\n",
    "There must exist a datasets directory at `/path/to/torchsig/datasets`.\n",
    "\n",
    "This example assumes that you have generated `train` and `val` lmdb wideband datasets at `./datasets/wideband/`\n",
    "\n",
    "You can also specify an absolute path to your dataset in `08_yolo_optuna.yaml`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc6137ed-44f5-4dd3-a16c-0267a84bfe55",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define dataset variables for yaml file\n",
    "config_name = \"08_yolo_optuna.yaml\"\n",
    "classes = {v: k for v, k in enumerate(class_list)}\n",
    "yolo_root = \"./wideband/\" # train/val images (relative to './datasets``\n",
    "\n",
    "# define overrides\n",
    "overrides = dict(\n",
    "    model = \"yolov8n.pt\",\n",
    "    project = \"yolo\",\n",
    "    name = \"08_example\",\n",
    "    epochs = 10,\n",
    "    imgsz = 512,\n",
    "    data = config_name,\n",
    "    device = 0 if torch.cuda.is_available() else \"cpu\",\n",
    "    single_cls = True,\n",
    "    batch = 32,\n",
    "    workers = 8\n",
    "\n",
    ")\n",
    "\n",
    "# create yaml file for trainer\n",
    "yolo_config = dict(\n",
    "    overrides = overrides,\n",
    "    train = yolo_root,\n",
    "    val = yolo_root,\n",
    "    nc = num_classes,\n",
    "    names = classes\n",
    ")\n",
    "\n",
    "with open(config_name, 'w+') as file:\n",
    "    yaml.dump(yolo_config, file, default_flow_style=False)\n",
    "print(f\"Creating experiment -> {overrides['name']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8f7ccf16",
   "metadata": {},
   "source": [
    "## Run Optuna\n",
    "Now we can start using optuna to tune the following hyperparameters in YOLO:\n",
    "- `lr0` - Initial learning rate.\n",
    "- `cos_lr` - Toggle cosine learning rate scheduler.\n",
    "- `optimizer` - Choose optimizer (SGD, Adam, AdamW)\n",
    "- `freeze` - Freeze the first N layers of the model.\n",
    "- `imgsz` - Target image size for training.\n",
    "\n",
    "Within the optuna optimizer,`n_trials` determines how many trials to run, while `epochs` is how many epochs are run per trial.\n",
    "\n",
    "See [Ultralytics Train Settings](https://docs.ultralytics.com/usage/cfg/#train-settings) for more hyperparameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d7017ba",
   "metadata": {},
   "outputs": [],
   "source": [
    "opt = YoloOptunaOptimizer(overrides, n_trials=2, epochs=1)\n",
    "study, best_params = opt.run_optimization()\n",
    "\n",
    "overrides_optimized = opt.get_optimized_overrides()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "22fa62fc-01d9-46b7-9cd2-38643f54d1de",
   "metadata": {},
   "source": [
    "## Train YOLO with Optimized Hyperparameters\n",
    "Train YOLO. See [Ultralytics Train](https://docs.ultralytics.com/modes/train/#train-settings) for training hyperparameter options.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ad7e5f57-3b5b-4074-b87c-6e18c7b973af",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "trainer = Yolo_Trainer(overrides=overrides_optimized)\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ad0c459-2367-4f7d-89ae-5f4cfd8f545f",
   "metadata": {},
   "source": [
    "## Evaluation\n",
    "Check model performance from training. From here, you can use the trained model to test on prepared data (numpy image arrays of spectrograms)\n",
    "\n",
    "Will load example from Torchsig\n",
    "\n",
    "model_path is path to best.pt from your training session. Path is printed at the end of training.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c45fa7bf-fa7d-4896-9022-4e608d93e5a4",
   "metadata": {},
   "source": [
    "## Generate and Instantiate WBSig53 Test Dataset\n",
    "After generating the WBSig53 dataset (see `03_example_Wideband_dataset.ipynb`), we can instantiate it with the needed transforms. Change `root` to test dataset path.\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "694c33f1-d718-40ae-861d-26c3431a6f41",
   "metadata": {},
   "outputs": [],
   "source": [
    "test_path = './datasets/wideband_test' #Should differ from your training dataset\n",
    "\n",
    "transform = Compose([\n",
    "    Spectrogram(nperseg=512, noverlap=0, nfft=512, mode='psd'),\n",
    "    Normalize(norm=np.inf, flatten=True),\n",
    "    SpectrogramImage(), \n",
    "    ])\n",
    "\n",
    "test_data = WidebandDataModule(\n",
    "    root=test_path,\n",
    "    impaired=impaired,\n",
    "    qa=qa,\n",
    "    fft_size=fft_size,\n",
    "    num_classes=num_classes,\n",
    "    transform=transform,\n",
    "    target_transform=None,\n",
    "    batch_size=batch_size\n",
    ")\n",
    "\n",
    "test_data.prepare_data()\n",
    "test_data.setup(\"fit\")\n",
    "\n",
    "wideband_test = test_data.train\n",
    "\n",
    "# Retrieve a sample and print out information\n",
    "idx = np.random.randint(len(wideband_test))\n",
    "data, label = wideband_test[idx]\n",
    "print(\"Dataset length: {}\".format(len(wideband_test)))\n",
    "print(\"Data shape: {}\".format(data.shape))\n",
    "\n",
    "samples = []\n",
    "labels = []\n",
    "for i in range(10):\n",
    "    idx = np.random.randint(len(wideband_test))\n",
    "    sample, label = wideband_test[idx]\n",
    "    lb = [l['class_name'] for l in label]\n",
    "    samples.append(sample)\n",
    "    labels.append(lb)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de9947fc-9636-415c-af50-26d7e5d95953",
   "metadata": {},
   "source": [
    "### Load model \n",
    "The model path is printed after training. Use the best.pt weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "76e73a9c",
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e2bb8735-ce6c-482b-9675-302f6848b0ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = YOLO(trainer.best)\n",
    "# Inference will be saved to path printed after predict. \n",
    "results = model.predict(samples, save=True, imgsz=512, conf=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "107d38d4-36af-4d0b-8bd9-926228b3e4ee",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Process results list\n",
    "for y, result in enumerate(results):\n",
    "    boxes = result.boxes  # Boxes object for bounding box outputs\n",
    "    probs = result.probs  # Probs object for classification outputs\n",
    "    print(f'Actual Labels -> {labels[y]}')\n",
    "    result.show()  # display to screen"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
